{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word2vec basic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Packages loaded\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import zipfile\n",
    "import numpy as np\n",
    "from six.moves import urllib\n",
    "from six.moves import xrange  # pylint: disable=redefined-builtin\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "%matplotlib inline  \n",
    "print (\"Packages loaded\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Download the text and build the corpus (a list of words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download (or reuse) the text file that we will use "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File already exists\n"
     ]
    }
   ],
   "source": [
    "# Where the text8 archive lives locally and where it is downloaded from.\n",
    "folder_dir  = \"data\"\n",
    "file_name   = \"text8.zip\"\n",
    "file_path   = os.path.join(folder_dir, file_name)\n",
    "url         = 'http://mattmahoney.net/dc/'\n",
    "if not os.path.exists(file_path):\n",
    "    print (\"No file found. Start downloading\")\n",
    "    # urlretrieve does not create missing parent folders, so make sure\n",
    "    # the target folder exists first (e.g. on a fresh checkout).\n",
    "    if not os.path.isdir(folder_dir):\n",
    "        os.makedirs(folder_dir)\n",
    "    downfilename, _ = urllib.request.urlretrieve(\n",
    "        url + file_name, file_path)\n",
    "    print (\"'%s' downloaded\" % (downfilename))\n",
    "else:\n",
    "    print (\"File already exists\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check we have correct data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I guess we have correct file at 'data/text8.zip'\n"
     ]
    }
   ],
   "source": [
    "# Sanity-check the archive by comparing its size on disk with the expected one.\n",
    "expected_bytes = 31344016\n",
    "statinfo = os.stat(file_path)\n",
    "if statinfo.st_size != expected_bytes:\n",
    "    print (\"Something's wrong with the file at '%s'\" % (file_path))\n",
    "else:\n",
    "    print (\"I guess we have correct file at '%s'\" % (file_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unzip the file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def read_data(filename):\n",
    "    \"\"\"Return the whitespace-separated tokens of the first member of a zip archive.\n",
    "\n",
    "    filename: path to the zip file; only the first entry of its name list is read.\n",
    "    \"\"\"\n",
    "    with zipfile.ZipFile(filename) as archive:\n",
    "        first_member = archive.namelist()[0]\n",
    "        raw_text = archive.read(first_member)\n",
    "    return raw_text.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'words' is <type 'list'> / Length is 17005207 \n",
      "'words' look like \n",
      " ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst', 'the', 'term', 'is', 'still', 'used', 'in', 'a', 'pejorative', 'way', 'to', 'describe', 'any', 'act', 'that', 'used', 'violent', 'means', 'to', 'destroy', 'the', 'organization', 'of', 'society', 'it', 'has', 'also', 'been', 'taken', 'up', 'as', 'a', 'positive', 'label', 'by', 'self', 'defined', 'anarchists', 'the', 'word', 'anarchism', 'is', 'derived', 'from', 'the', 'greek', 'without', 'archons', 'ruler', 'chief', 'king', 'anarchism', 'as', 'a', 'political', 'philosophy', 'is', 'the', 'belief', 'that', 'rulers', 'are', 'unnecessary', 'and', 'should', 'be', 'abolished', 'although', 'there', 'are', 'differing']\n"
     ]
    }
   ],
   "source": [
    "# Read the whole corpus into memory as one flat list of tokens.\n",
    "words = read_data(file_path) \n",
    "print (\"Type of 'words' is %s / Length is %d \" \n",
    "       % (type(words), len(words)))\n",
    "print (\"'words' look like \\n %s\" %(words[0:100]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 2. Make a dictionary with fixed length (using UNK token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Count the words "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'count' is <type 'list'> / Length is 50000 \n",
      "'count' looks like \n",
      " [['UNK', -1], ('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764), ('in', 372201), ('a', 325873), ('to', 316376), ('zero', 264975), ('nine', 250430)]\n"
     ]
    }
   ],
   "source": [
    "vocabulary_size = 50000\n",
    "# Entry 0 is a placeholder for out-of-vocabulary words; the remaining\n",
    "# vocabulary_size - 1 slots hold the most frequent words with their counts.\n",
    "count = [['UNK', -1]] + collections.Counter(words).most_common(vocabulary_size - 1)\n",
    "print (\"Type of 'count' is %s / Length is %d \" % (type(count), len(count)))\n",
    "print (\"'count' looks like \\n %s\" % (count[0:10]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make a dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'dictionary' is <type 'dict'> / Length is 50000 \n"
     ]
    }
   ],
   "source": [
    "# Give every word in 'count' the next free integer id: 'UNK' gets 0 and\n",
    "# more frequent words get smaller ids.\n",
    "dictionary = dict() \n",
    "for word, _ in count:\n",
    "    dictionary[word] = len(dictionary)\n",
    "print (\"Type of 'dictionary' is %s / Length is %d \" \n",
    "       % (type(dictionary), len(dictionary))) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make a reverse dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'reverse_dictionary' is <type 'dict'> / Length is 50000 \n"
     ]
    }
   ],
   "source": [
    "# Invert 'dictionary' so that an integer id maps back to its word.\n",
    "reverse_dictionary = {token_id: token for token, token_id in dictionary.items()}\n",
    "print (\"Type of 'reverse_dictionary' is %s / Length is %d \" \n",
    "       % (type(reverse_dictionary), len(reverse_dictionary))) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Re-encode the corpus as integer ids, counting the words that fall outside the vocabulary.\n",
    "data = []\n",
    "unk_count = 0\n",
    "for word in words:\n",
    "    index = dictionary.get(word)\n",
    "    if index is None:\n",
    "        # Out-of-vocabulary word: map it to id 0 (dictionary['UNK']).\n",
    "        index = 0\n",
    "        unk_count += 1\n",
    "    data.append(index)\n",
    "# Replace the placeholder count of 'UNK' with the real number of unknown tokens.\n",
    "count[0][1] = unk_count\n",
    "# del words  # Hint to reduce memory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 'dictionary' converts word to index \n",
    "### 'reverse_dictionary' converts index to word "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Most common words (+UNK) are: [['UNK', 418391], ('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764)]\n"
     ]
    }
   ],
   "source": [
    "# count[0] now carries the real number of out-of-vocabulary tokens instead of -1.\n",
    "print (\"Most common words (+UNK) are: %s\" % (count[:5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data (in indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample data: [5239, 3084, 12, 6, 195, 2, 3137, 46, 59, 156]\n"
     ]
    }
   ],
   "source": [
    "# The corpus re-encoded as integer word ids (first ten entries).\n",
    "print (\"Sample data: %s\" % (data[:10])) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert indices back to words (which we can read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample data corresponds to\n",
      "__________________\n",
      "5239->anarchism\n",
      "3084->originated\n",
      "12->as\n",
      "6->a\n",
      "195->term\n",
      "2->of\n",
      "3137->abuse\n",
      "46->first\n",
      "59->used\n",
      "156->against\n"
     ]
    }
   ],
   "source": [
    "print (\"Sample data corresponds to\\n__________________\")\n",
    "# Map the first ten ids back to their words via the reverse dictionary.\n",
    "for i in range(10):\n",
    "    print (\"%d->%s\" % (data[i], reverse_dictionary[data[i]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch-generating function for skip-gram model\n",
    "## - Skip-gram (one word to one word) => Can generate more training data\n",
    "\n",
    "<img src=\"images/etc/word2vec_desc.png\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_index = 0\n",
    "def generate_batch(batch_size, num_skips, skip_window):\n",
    "    \"\"\"Build one skip-gram training batch from the global 'data' sequence.\n",
    "\n",
    "    batch_size:  number of (input, label) pairs; must be a multiple of num_skips.\n",
    "    num_skips:   number of context words drawn for every centre word\n",
    "                 (at most 2 * skip_window).\n",
    "    skip_window: number of words considered on each side of the centre word.\n",
    "\n",
    "    Returns (batch, labels): an int32 array of shape (batch_size,) holding\n",
    "    centre-word ids and an int32 array of shape (batch_size, 1) holding the\n",
    "    matching context-word ids. The global cursor 'data_index' is advanced\n",
    "    (wrapping around at the end of 'data') so consecutive calls walk the corpus.\n",
    "    \"\"\"\n",
    "    global data_index\n",
    "    assert batch_size % num_skips == 0\n",
    "    assert num_skips <= 2 * skip_window\n",
    "    batch  = np.ndarray(shape=(batch_size),    dtype=np.int32)\n",
    "    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)\n",
    "    span = 2 * skip_window + 1 # [ skip_window target skip_window ]\n",
    "    window = collections.deque(maxlen=span)  # sliding window over 'data'\n",
    "    for _ in range(span):\n",
    "        window.append(data[data_index])\n",
    "        data_index = (data_index + 1) % len(data)\n",
    "    for i in range(batch_size // num_skips): # '//' makes the result an integer, e.g., 7//3 = 2\n",
    "        target = skip_window\n",
    "        targets_to_avoid = [ skip_window ]\n",
    "        for j in range(num_skips):\n",
    "            # Draw a context position that is neither the centre nor already used.\n",
    "            while target in targets_to_avoid:\n",
    "                target = random.randint(0, span - 1)\n",
    "            targets_to_avoid.append(target)\n",
    "            batch[i * num_skips + j] = window[skip_window]\n",
    "            labels[i * num_skips + j, 0] = window[target]\n",
    "        window.append(data[data_index])\n",
    "        data_index = (data_index + 1) % len(data)\n",
    "    # The cursor is now 'span' words past the start of the next unused window;\n",
    "    # rewind it so the next call continues right after the last centre word\n",
    "    # instead of skipping those words.\n",
    "    data_index = (data_index + len(data) - span) % len(data)\n",
    "    return batch, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examples for generating batch and labels "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'batch' is <type 'numpy.ndarray'> / Length is 8 \n",
      "Type of 'labels' is <type 'numpy.ndarray'> / Length is 8 \n"
     ]
    }
   ],
   "source": [
    "# Rewind the corpus cursor so this example starts at the first word.\n",
    "data_index = 0\n",
    "batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)\n",
    "print (\"Type of 'batch' is %s / Length is %d \" \n",
    "       % (type(batch), len(batch))) \n",
    "print (\"Type of 'labels' is %s / Length is %d \" \n",
    "       % (type(labels), len(labels))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'batch' looks like \n",
      " [3084 3084   12   12    6    6  195  195]\n"
     ]
    }
   ],
   "source": [
    "# Centre-word ids; each one appears num_skips (= 2 here) times in a row.\n",
    "print (\"'batch' looks like \\n %s\" % (batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'labels' looks like \n",
      " [[5239]\n",
      " [  12]\n",
      " [3084]\n",
      " [   6]\n",
      " [ 195]\n",
      " [  12]\n",
      " [   2]\n",
      " [   6]]\n"
     ]
    }
   ],
   "source": [
    "# Context-word ids, one row per entry of 'batch'.\n",
    "print (\"'labels' looks like \\n %s\" % (labels)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3084 -> 5239 \toriginated -> anarchism\n",
      "3084 -> 12 \toriginated -> as\n",
      "12 -> 3084 \tas -> originated\n",
      "12 -> 6 \tas -> a\n",
      "6 -> 195 \ta -> term\n",
      "6 -> 12 \ta -> as\n",
      "195 -> 2 \tterm -> of\n",
      "195 -> 6 \tterm -> a\n"
     ]
    }
   ],
   "source": [
    "# Show every (input -> label) pair both as ids and as words. A single print call\n",
    "# keeps each pair on one line regardless of how the interpreter treats a\n",
    "# trailing comma after print.\n",
    "for i in range(len(batch)):\n",
    "    print (\"%d -> %d \\t%s -> %s\" \n",
    "           % (batch[i], labels[i, 0]\n",
    "              , reverse_dictionary[batch[i]]\n",
    "              , reverse_dictionary[labels[i, 0]])) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Build a Skip-Gram Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters ready\n"
     ]
    }
   ],
   "source": [
    "batch_size     = 128       # Number of (input, label) pairs fed per training step.\n",
    "embedding_size = 128       # Dimension of the embedding vector.\n",
    "skip_window    = 1         # How many words to consider left and right.\n",
    "num_skips      = 2         # How many times to reuse an input to generate a label.\n",
    "print (\"Parameters ready\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 31  89  64  57 122 178 173  13  50 176  12 188   0  19 168  44  59 130\n",
      " 126  75 187 161  10 194 119 131   5  69 150 165   6  70]\n"
     ]
    }
   ],
   "source": [
    "# Random validation set to sample nearest neighbors.\n",
    "# Word ids were assigned in order of decreasing frequency, so drawing ids\n",
    "# below valid_window restricts the sample to the most frequent entries.\n",
    "valid_size     = 32        # Random set of words to evaluate similarity \n",
    "valid_window   = 200       # Only pick validation samples in the top 200\n",
    "valid_examples = np.random.choice(valid_window, valid_size, replace=False)\n",
    "\n",
    "print (valid_examples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network ready\n"
     ]
    }
   ],
   "source": [
    "# Construct the word2vec model \n",
    "# train_inputs / train_labels receive one batch produced by generate_batch;\n",
    "# valid_dataset holds the fixed ids whose nearest neighbours are inspected.\n",
    "train_inputs   = tf.placeholder(tf.int32, shape=[batch_size])   \n",
    "train_labels   = tf.placeholder(tf.int32, shape=[batch_size, 1])\n",
    "valid_dataset  = tf.constant(valid_examples, dtype=tf.int32)\n",
    "\n",
    "# Look up embeddings for inputs. (vocabulary_size = 50,000)\n",
    "# 'embeddings' has one embedding_size-dimensional row per vocabulary word;\n",
    "# 'embed' selects the rows for the ids in the current batch.\n",
    "with tf.variable_scope(\"EMBEDDING\"):\n",
    "    with tf.device('/cpu:0'):\n",
    "        embeddings = tf.Variable(\n",
    "            tf.random_uniform([vocabulary_size, embedding_size]\n",
    "                              , -1.0, 1.0))\n",
    "        embed = tf.nn.embedding_lookup(embeddings, train_inputs)\n",
    "    \n",
    "# Construct the variables for the NCE loss\n",
    "# (a weight row and a bias entry for every vocabulary word)\n",
    "with tf.variable_scope(\"NCE_WEIGHT\"):\n",
    "    nce_weights = tf.Variable(\n",
    "        tf.truncated_normal([vocabulary_size, embedding_size],\n",
    "                            stddev=1.0 / math.sqrt(embedding_size)))\n",
    "    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))\n",
    "print (\"Network ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Functions Ready\n"
     ]
    }
   ],
   "source": [
    "with tf.device('/cpu:0'):\n",
    "    # Loss function \n",
    "    num_sampled = 64        # Number of negative examples to sample. \n",
    "    # Keyword arguments are used on purpose: the positional order of 'inputs'\n",
    "    # and 'labels' is not the same in every TensorFlow release, so passing\n",
    "    # them by position can silently swap the two.\n",
    "    loss = tf.reduce_mean(\n",
    "        tf.nn.nce_loss(weights=nce_weights, biases=nce_biases\n",
    "                       , inputs=embed, labels=train_labels\n",
    "                       , num_sampled=num_sampled\n",
    "                       , num_classes=vocabulary_size))\n",
    "    # Optimizer\n",
    "    optm = tf.train.GradientDescentOptimizer(1.0).minimize(loss)\n",
    "    # Similarity measure (important)\n",
    "    # Every embedding row is divided by its L2 norm, so the matrix product\n",
    "    # below is the cosine similarity between validation words and all words.\n",
    "    norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))\n",
    "    normalized_embeddings = embeddings / norm\n",
    "    valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings\n",
    "                    , valid_dataset)\n",
    "    siml = tf.matmul(valid_embeddings, normalized_embeddings\n",
    "                    , transpose_b=True)\n",
    "    \n",
    "print (\"Functions Ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Train a Skip-Gram Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss at step 0 is 0.141\n",
      "Nearest to 'an': 'yorktown', 'fux', 'katydids', 'leary', 'interruption', 'detox',\n",
      "Nearest to 'called': 'wong', 'germaine', 'electrophilic', 'kilpatrick', 'alkenes', 'retained',\n",
      "Nearest to 'american': 'mq', 'reef', 'guang', 'seas', 'puma', 'rq',\n",
      "Nearest to 'who': 'scarred', 'directory', 'spate', 'denison', 'incensed', 'adipose',\n",
      "Nearest to 'c': 'chapman', 'charlize', 'lading', 'ethnographers', 'stingray', 'vainly',\n",
      "Nearest to 'set': 'emptive', 'menelik', 'headed', 'unsubstantiated', 'churchyard', 'horticultural',\n",
      "Nearest to 'different': 'unreleased', 'arameans', 'psychologist', 'earmarked', 'phylum', 'ayer',\n",
      "Nearest to 'eight': 'specificity', 'leftover', 'fulfil', 'urgell', 'oak', 'columba',\n",
      "Nearest to 'all': 'transport', 'linspire', 'eo', 'donatello', 'georgi', 'picturesque',\n",
      "Nearest to 'do': 'megabits', 'condorcet', 'landau', 'coincidental', 'ransom', 'proper',\n",
      "Nearest to 'as': 'yesterday', 'hw', 'solace', 'modeled', 'hats', 'conant',\n",
      "Nearest to 'law': 'gleaned', 'fickle', 'expressive', 'gory', 'diem', 'threatens',\n",
      "Nearest to 'UNK': 'probate', 'guests', 'determines', 'dave', 'ubiquity', 'inks',\n",
      "Nearest to 'by': 'parton', 'mccann', 'mahoney', 'diasporas', 'benefiting', 'nieuwe',\n",
      "Nearest to 'large': 'towers', 'multiculturalism', 'endowment', 'enhancing', 'brainwashing', 'eia',\n",
      "Nearest to 'their': 'ecc', 'somber', 'quaternary', 'schopenhauer', 'trombone', 'brethren',\n",
      "Nearest to 'used': 'shrunk', 'fixtures', 'ultraviolet', 'volley', 'hab', 'spiced',\n",
      "Nearest to 'british': 'equinox', 'shandy', 'socialism', 'revert', 'zeeland', 'progs',\n",
      "Nearest to 'x': 'callous', 'envoys', 'gamecube', 'unintelligent', 'rosicrucian', 'participated',\n",
      "Nearest to 'd': 'francesco', 'schumacher', 'creighton', 'tubas', 'indulgence', 'dip',\n",
      "Nearest to 'king': 'circumscribed', 'niacin', 'inter', 'bluescreen', 'catapult', 'farnese',\n",
      "Nearest to 'based': 'oj', 'odi', 'fubar', 'jogaila', 'enables', 'refrigerated',\n",
      "Nearest to 'two': 'orbits', 'hostel', 'skeleton', 'factories', 'borghese', 'panned',\n",
      "Nearest to 'international': 'da', 'confidential', 'hiv', 'rhythmic', 'incipient', 'talkie',\n",
      "Nearest to 'him': 'reelection', 'trivia', 'counters', 'lsch', 'visual', 'seaport',\n",
      "Nearest to 'year': 'overwhelming', 'astaire', 'conclude', 'donner', 'honky', 'objects',\n",
      "Nearest to 'in': 'events', 'golem', 'elbert', 'lung', 'rsc', 'coordinated',\n",
      "Nearest to 'may': 'sur', 'formalized', 'inasmuch', 'telekinetic', 'bisexual', 'duchamp',\n",
      "Nearest to 'g': 'aruban', 'africans', 'jason', 'perutz', 'stun', 'hahn',\n",
      "Nearest to 'de': 'compactification', 'brugge', 'disarming', 'plotted', 'env', 'fielder',\n",
      "Nearest to 'a': 'ally', 'motorola', 'concrete', 'catcher', 'necks', 'mortimer',\n",
      "Nearest to 'than': 'karts', 'gwynedd', 'hesse', 'agglutinative', 'there', 'ensues',\n",
      "Average loss at step 2000 is 114.071\n",
      "Average loss at step 4000 is 52.629\n",
      "Average loss at step 6000 is 33.005\n",
      "Average loss at step 8000 is 23.590\n",
      "Average loss at step 10000 is 17.952\n",
      "Nearest to 'an': 'the', 'overseas', 'protein', 'interruption', 'victoriae', 'mya',\n",
      "Nearest to 'called': 'gland', 'within', 'gb', 'much', 'retained', 'consists',\n",
      "Nearest to 'american': 'olympics', 'methadone', 'linguistic', 'technologically', 'while', 'spider',\n",
      "Nearest to 'who': 'and', 'directory', 'named', 'attacked', 'mode', 'ruth',\n",
      "Nearest to 'c': 'victoriae', 'accounts', 'mathbf', 'reginae', 'austin', 'passion',\n",
      "Nearest to 'set': 'reginae', 'shortcuts', 'animals', 'few', 'members', 'majority',\n",
      "Nearest to 'different': 'mya', 'unreleased', 'cl', 'phylum', 'psychologist', 'reginae',\n",
      "Nearest to 'eight': 'nine', 'zero', 'six', 'reginae', 'agave', 'vs',\n",
      "Nearest to 'all': 'transport', 'beginning', 'bestseller', 'cl', 'waugh', 'lindbergh',\n",
      "Nearest to 'do': 'vs', 'proper', 'assistant', 'symbols', 'reflect', 'christian',\n",
      "Nearest to 'as': 'in', 'for', 'and', 'by', 'irish', 'asterism',\n",
      "Nearest to 'law': 'collapse', 'cc', 'soviet', 'accordion', 'badges', 'kyoto',\n",
      "Nearest to 'UNK': 'and', 'alpina', 'mengele', 'the', 'reginae', 'one',\n",
      "Nearest to 'by': 'in', 'and', 'as', 'at', 'gland', 'was',\n",
      "Nearest to 'large': 'towers', 'gland', 'fricatives', 'conducted', 'greatest', 'charlie',\n",
      "Nearest to 'their': 'reginae', 'schopenhauer', 'afghani', 'the', 'sabotage', 'implicit',\n",
      "Nearest to 'used': 'spiced', 'mystery', 'mathbf', 'sacred', 'mining', 'fixtures',\n",
      "Nearest to 'british': 'anatomy', 'socialism', 'equinox', 'victoriae', 'astor', 'five',\n",
      "Nearest to 'x': 'phi', 'right', 'participated', 'modern', 'countries', 'video',\n",
      "Nearest to 'd': 'finalist', 'avoid', 'tubing', 'vs', 'francesco', 'schumacher',\n",
      "Nearest to 'king': 'inter', 'thompson', 'negligible', 'moves', 'alpina', 'wisconsin',\n",
      "Nearest to 'based': 'reginae', 'chemist', 'dim', 'directory', 'traditionally', 'agave',\n",
      "Nearest to 'two': 'one', 'reginae', 'vs', 'austin', 'gollancz', 'three',\n",
      "Nearest to 'international': 'da', 'hiv', 'confidential', 'access', 'six', 'robots',\n",
      "Nearest to 'him': 'infectious', 'trivia', 'spontaneous', 'seaport', 'directors', 'visual',\n",
      "Nearest to 'year': 'conclude', 'objects', 'overwhelming', 'clear', 'responsibility', 'vs',\n",
      "Nearest to 'in': 'and', 'of', 'mya', 'with', 'on', 'by',\n",
      "Nearest to 'may': 'sur', 'cannot', 'tubing', 'formalized', 'individualist', 'faure',\n",
      "Nearest to 'g': 'eight', 'aruban', 'jason', 'africans', 'i', 'am',\n",
      "Nearest to 'de': 'publicly', 'oxus', 'prize', 'shortly', 'course', 'aa',\n",
      "Nearest to 'a': 'the', 'austin', 'UNK', 'and', 'reginae', 'alpina',\n",
      "Nearest to 'than': 'karts', 'community', 'hesse', 'there', 'racial', 'caves',\n",
      "Average loss at step 12000 is 13.873\n",
      "Average loss at step 14000 is 11.914\n",
      "Average loss at step 16000 is 9.810\n",
      "Average loss at step 18000 is 8.702\n",
      "Average loss at step 20000 is 7.841\n",
      "Nearest to 'an': 'the', 'protein', 'overseas', 'their', 'interruption', 'counterexample',\n",
      "Nearest to 'called': 'within', 'gland', 'hiroshima', 'litigants', 'dasyprocta', 'clan',\n",
      "Nearest to 'american': 'and', 'methadone', 'olympics', 'while', 'technologically', 'certainty',\n",
      "Nearest to 'who': 'and', 'attacked', 'also', 'ruth', 'named', 'agouti',\n",
      "Nearest to 'c': 'eight', 'victoriae', 'accounts', 'one', 'blog', 'chapman',\n",
      "Nearest to 'set': 'reginae', 'shortcuts', 'astoria', 'headquarters', 'members', 'few',\n",
      "Nearest to 'different': 'unreleased', 'mya', 'cl', 'agouti', 'psychologist', 'truetype',\n",
      "Nearest to 'eight': 'nine', 'six', 'zero', 'five', 'seven', 'three',\n",
      "Nearest to 'all': 'bestseller', 'beginning', 'transport', 'waugh', 'cl', 'lindbergh',\n",
      "Nearest to 'do': 'condorcet', 'vs', 'proper', 'assistant', 'reflect', 'symbols',\n",
      "Nearest to 'as': 'and', 'for', 'by', 'was', 'is', 'in',\n",
      "Nearest to 'law': 'agouti', 'threatens', 'collapse', 'clem', 'sumer', 'soviet',\n",
      "Nearest to 'UNK': 'agouti', 'alpina', 'victoriae', 'reginae', 'dasyprocta', 'and',\n",
      "Nearest to 'by': 'was', 'and', 'in', 'from', 'as', 'with',\n",
      "Nearest to 'large': 'towers', 'aquila', 'enhancing', 'gland', 'fricatives', 'additions',\n",
      "Nearest to 'their': 'the', 'reginae', 'his', 'afghani', 'a', 'this',\n",
      "Nearest to 'used': 'spiced', 'mystery', 'haliotis', 'mining', 'ultraviolet', 'algol',\n",
      "Nearest to 'british': 'equinox', 'anatomy', 'astor', 'socialism', 'victoriae', 'vivian',\n",
      "Nearest to 'x': 'actaeon', 'envoys', 'four', 'agouti', 'apatosaurus', 'eight',\n",
      "Nearest to 'd': 'b', 'and', 'francesco', 'finalist', 'asphyxia', 'one',\n",
      "Nearest to 'king': 'circumscribed', 'inter', 'niacin', 'iphigenia', 'four', 'moves',\n",
      "Nearest to 'based': 'chemist', 'charlie', 'encryption', 'reginae', 'competitions', 'dim',\n",
      "Nearest to 'two': 'one', 'three', 'five', 'nine', 'six', 'four',\n",
      "Nearest to 'international': 'da', 'UNK', 'hiv', 'confidential', 'aloe', 'length',\n",
      "Nearest to 'him': 'seaport', 'spontaneous', 'infectious', 'trivia', 'directors', 'marine',\n",
      "Nearest to 'year': 'conclude', 'five', 'overwhelming', 'responsibility', 'objects', 'subkey',\n",
      "Nearest to 'in': 'and', 'with', 'of', 'on', 'mya', 'from',\n",
      "Nearest to 'may': 'sur', 'cannot', 'can', 'seleucid', 'agouti', 'tubing',\n",
      "Nearest to 'g': 'agouti', 'aruban', 'eight', 'jason', 'i', 'dasyprocta',\n",
      "Nearest to 'de': 'fita', 'marischal', 'publicly', 'shortly', 'electrophoresis', 'winner',\n",
      "Nearest to 'a': 'the', 'agouti', 'reginae', 'victoriae', 'and', 'afghani',\n",
      "Nearest to 'than': 'community', 'karts', 'hesse', 'caves', 'eight', 'ensues',\n",
      "Average loss at step 22000 is 7.241\n"
     ]
    }
   ],
   "source": [
    "# Train! \n",
    "sess = tf.Session()\n",
    "sess.run(tf.initialize_all_variables())\n",
    "# Records the session graph under /tmp/tf_logs/word2vec.\n",
    "summary_writer = tf.train.SummaryWriter('/tmp/tf_logs/word2vec', graph=sess.graph)\n",
    "average_loss = 0\n",
    "\n",
    "num_steps    = 100001\n",
    "report_every = 2000    # steps between two loss reports\n",
    "eval_every   = 10000   # steps between two nearest-neighbour listings\n",
    "top_k        = 6       # number of nearest neighbors\n",
    "# The loop variable is called 'step' so that the builtin iter() is not shadowed.\n",
    "for step in xrange(num_steps):\n",
    "    batch_inputs, batch_labels = generate_batch(batch_size, num_skips, skip_window)\n",
    "    feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}\n",
    "    _, loss_val = sess.run([optm, loss], feed_dict=feed_dict)\n",
    "    average_loss += loss_val\n",
    "    \n",
    "    if step % report_every == 0:\n",
    "        # At step 0 only a single batch has been accumulated, so it is reported as is.\n",
    "        if step > 0:\n",
    "            average_loss /= report_every\n",
    "        print (\"Average loss at step %d is %.3f\" % (step, average_loss)) \n",
    "        # Start a fresh window; otherwise the previous average leaks into the next one.\n",
    "        average_loss = 0\n",
    "    \n",
    "    if step % eval_every == 0:\n",
    "        siml_val = sess.run(siml)\n",
    "        for i in xrange(valid_size): # Among valid set \n",
    "            valid_word = reverse_dictionary[valid_examples[i]]\n",
    "            # Sort by descending similarity and drop position 0, which is\n",
    "            # expected to be the word itself.\n",
    "            nearest = (-siml_val[i, :]).argsort()[1:top_k+1]\n",
    "            log_str = \"Nearest to '%s':\" % valid_word\n",
    "            for k in xrange(top_k):\n",
    "                close_word = reverse_dictionary[nearest[k]] \n",
    "                log_str = \"%s '%s',\" % (log_str, close_word)\n",
    "            print(log_str) \n",
    "            \n",
    "# Final embedding \n",
    "final_embeddings = sess.run(normalized_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Visualize the embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):\n",
    "    \"\"\"Scatter-plot 2-D points and annotate each one with its label.\n",
    "\n",
    "    low_dim_embs: array with one (x, y) row per point; needs at least len(labels) rows.\n",
    "    labels:       text annotations; labels[i] is drawn next to row i.\n",
    "    filename:     path the figure is saved to; pass None or '' to skip saving.\n",
    "    \"\"\"\n",
    "    assert low_dim_embs.shape[0] >= len(labels), \"More labels than embeddings\"\n",
    "    plt.figure(figsize=(18, 18))  #in inches\n",
    "    for i, label in enumerate(labels):\n",
    "        x, y = low_dim_embs[i,:]\n",
    "        plt.scatter(x, y)\n",
    "        plt.annotate(label,\n",
    "                 xy=(x, y),\n",
    "                 xytext=(5, 2),\n",
    "                 textcoords='offset points',\n",
    "                 ha='right',\n",
    "                 va='bottom')\n",
    "    # Honour the 'filename' argument; saving happens before the figure is shown.\n",
    "    if filename:\n",
    "        plt.savefig(filename)\n",
    "    plt.show()\n",
    "    \n",
    "# Plot\n",
    "tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)\n",
    "plot_only = 500\n",
    "low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only,:])\n",
    "# A dedicated name is used so the global 'labels' array created by the\n",
    "# batch example cell is not overwritten.\n",
    "word_labels = [reverse_dictionary[i] for i in xrange(plot_only)]\n",
    "plot_with_labels(low_dim_embs, word_labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
